Fix WeightAveraging swapping in the un-updated average model during validation #21732
Open
ATOM00blue wants to merge 2 commits into
…ation Before its first update, the AveragedModel only holds the copy of the initial weights made in setup(). The validation hooks swapped it in unconditionally, so during a delayed-start warmup (e.g. EMAWeightAveraging with update_starting_at_step) validation evaluated the untrained snapshot instead of the current trained weights. Only swap the models when the average model has been updated at least once (n_averaged > 0). The swap stays balanced across validation start/end since n_averaged does not change during validation.
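To illustrate the guard described in this commit message, here is a minimal pure-Python sketch. ToyAverager and its attributes are hypothetical stand-ins, not Lightning's actual callback code; only the hook names and the n_averaged > 0 condition come from the description above. The running-mean update is an assumption, chosen so that the first update simply copies the current weights.

```python
class ToyAverager:
    """Hypothetical stand-in for the callback; not Lightning's implementation."""

    def __init__(self, weights):
        self.current = list(weights)  # weights being trained
        self.average = list(weights)  # snapshot copied at setup time
        self.n_averaged = 0           # number of averaging updates so far

    def update_average(self):
        # Assumed running mean: with n_averaged == 0 this copies the current weights.
        n = self.n_averaged
        self.average = [(a * n + c) / (n + 1) for a, c in zip(self.average, self.current)]
        self.n_averaged += 1

    def _swap(self):
        self.current, self.average = self.average, self.current

    def on_validation_epoch_start(self):
        # Guard: only validate on the average once it has been updated.
        if self.n_averaged > 0:
            self._swap()

    def on_validation_epoch_end(self):
        # Same condition as the start hook, so the swap is always undone.
        if self.n_averaged > 0:
            self._swap()


toy = ToyAverager([0.0, 0.0])
toy.current = [1.0, 2.0]  # pretend training moved the weights during warmup

toy.on_validation_epoch_start()
seen_during_warmup = list(toy.current)  # current trained weights, not the snapshot
toy.on_validation_epoch_end()

toy.update_average()
toy.current = [3.0, 4.0]
toy.update_average()

toy.on_validation_epoch_start()
seen_after_updates = list(toy.current)  # now the averaged weights
toy.on_validation_epoch_end()
```

In this sketch n_averaged is only changed by update_average, so both hooks evaluate the same condition within one validation run and every swap made at epoch start is undone at epoch end.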
for more information, see https://pre-commit.ci
Pull request overview
Fixes a bug in the WeightAveraging / EMAWeightAveraging callbacks where validation could swap in the averaged model even before it had received its first update (n_averaged == 0), causing validation to run on an untrained initial-weight snapshot during delayed-start warmup.
Changes:
- Guarded validation-time model swapping so it only happens after the averaged model has been updated at least once (n_averaged > 0).
- Updated SWA test expectations to reflect that validation swapping now starts only after the first averaging update occurs.
- Added a regression test ensuring validation observes current trained weights (not the frozen initial snapshot) when the averaging update threshold is never reached.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| src/lightning/pytorch/callbacks/weight_averaging.py | Prevents swapping in the averaged model for validation until n_averaged > 0, avoiding evaluation on an untrained snapshot. |
| tests/tests_pytorch/callbacks/test_weight_averaging.py | Adjusts swap-count expectations and adds a regression test for “no swap before first update” behavior. |
| src/lightning/pytorch/CHANGELOG.md | Documents the fix under the unreleased “Fixed” section. |
I think it's better to use |
What does this PR do?
Fixes #21724
WeightAveraging (and its EMAWeightAveraging subclass) creates the AveragedModel in setup() as a copy of the model's initial weights, with n_averaged == 0. The validation hooks on_validation_epoch_start / on_validation_epoch_end swapped this average model in unconditionally whenever it existed.
When using a delayed start (e.g. EMAWeightAveraging(update_starting_at_step=1000)) and validating during the warmup period, the average model has never been updated, so validation ran against the frozen initial (untrained) weights instead of the current trained ones. This is what the issue describes as metrics being near zero before update_starting_at_step.
This PR only swaps the models for validation once the average model has actually been updated at least once (n_averaged > 0). The swap remains balanced across the start/end hooks because n_averaged does not change during validation.
Tests
- Added test_weight_averaging_no_swap_before_first_update, which verifies that during a never-reached delayed start the parameters seen at validation are the current trained weights, not the frozen initial snapshot. It fails before this change and passes after.
- Updated SWATestCallback swap-count expectations: with a delayed update schedule, validation now only swaps once the average model has been updated.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
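For anyone reviewing, the idea behind the regression test listed under Tests can be reproduced without Lightning. The sketch below is a hypothetical toy loop with a single scalar weight. The function simulate_validation, its parameters other than update_starting_at_step, the step >= update_starting_at_step condition, and the running-mean update are all assumptions made for illustration; this is not the code in test_weight_averaging.py.

```python
def simulate_validation(update_starting_at_step, total_steps, validate_every, guarded):
    """Return the weight value that validation sees at each validation run.

    Hypothetical toy loop; it does not use or mirror Lightning's Trainer.
    """
    current = 0.0      # trained weight, starting from its initial value
    average = current  # snapshot copied at setup time
    n_averaged = 0
    seen = []
    for step in range(1, total_steps + 1):
        current += 1.0  # pretend each training step moves the weight by 1
        if step >= update_starting_at_step:  # assumed delayed-start condition
            # Assumed running mean; the first update copies the current weight.
            average = (average * n_averaged + current) / (n_averaged + 1)
            n_averaged += 1
        if step % validate_every == 0:
            # guarded=True models the behavior after this PR, False the old behavior.
            swap = n_averaged > 0 if guarded else True
            if swap:
                current, average = average, current
            seen.append(current)
            if swap:
                current, average = average, current  # undo the swap
    return seen


# Delayed start that is never reached within the run.
warmup_fixed = simulate_validation(1000, total_steps=4, validate_every=2, guarded=True)
warmup_buggy = simulate_validation(1000, total_steps=4, validate_every=2, guarded=False)
print(warmup_fixed)  # [2.0, 4.0]: validation follows the trained weight
print(warmup_buggy)  # [0.0, 0.0]: validation is stuck on the initial snapshot
```

The unguarded run reproduces the symptom from the issue: validation keeps seeing the initial value for as long as the update threshold has not been reached.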